import copy
import io
import logging
import os
import re
import tempfile
from io import BytesIO
from typing import Dict, List, Optional

import fitz
import requests
from datasets import Dataset, Features, Sequence, Value
from tqdm import tqdm


def create_url_dataset(urls: List[str]) -> Dataset:
    """
    Build a single-column HuggingFace dataset from a list of PDF URLs.

    Args:
        urls (List[str]): List of PDF URLs

    Returns:
        Dataset: HuggingFace dataset with columns:
            - url: source URL
    """
    columns: Dict[str, List[str]] = {"url": urls}
    return Dataset.from_dict(columns)


def get_all_page_counts(dataset: Dataset) -> Dataset:
    """
    Add a 'page_count' column to the dataset by fetching each PDF.

    Best-effort: rows whose PDF cannot be fetched or parsed get
    page_count = 0 instead of aborting the whole map.

    Args:
        dataset (Dataset): Dataset containing URLs in 'url' column

    Returns:
        Dataset: Same dataset with an added 'page_count' column
    """
    logger = logging.getLogger(__name__)

    def _add_page_count(example):
        try:
            example["page_count"] = get_pdf_page_count(example["url"])
        except Exception as e:
            # Deliberately broad: any network/parse failure marks the row
            # with 0 pages so downstream expansion simply skips it.
            logger.warning(
                "Error getting page count for %s: %s", example["url"], e
            )
            example["page_count"] = 0
        return example

    return dataset.map(_add_page_count)


def download_pdfs(dataset: Dataset) -> Dataset:
    """
    Download PDFs from URLs and add serialized content to dataset.
    Uses parallel processing via dataset.map()

    Args:
        dataset (Dataset): Dataset containing URLs in 'url' column

    Returns:
        Dataset: Original dataset with additional columns:
            - pdf_bytes: serialized PDF content (empty bytes on failure)
            - filename: extracted filename from URL
            - success: whether download succeeded
    """
    logger = logging.getLogger(__name__)

    def download_single_pdf(example):
        url = example["url"]
        # Derive a filename from the last path segment of the URL.
        filename = url.split("/")[-1]
        if not filename.lower().endswith(".pdf"):
            filename += ".pdf"

        try:
            # A timeout prevents one hung connection from stalling a
            # worker process forever.
            response = requests.get(url, timeout=60)
            if response.status_code == 200:
                return {
                    **example,
                    "pdf_bytes": response.content,
                    "filename": filename,
                    "success": True,
                }
        except requests.RequestException as e:
            # Don't let a single bad URL crash the whole parallel map;
            # the row is kept and flagged via success=False.
            logger.warning("Failed to download %s: %s", url, e)

        # Use empty *bytes* (not "") so the column has a single Arrow
        # type consistent with the success path.
        return {**example, "pdf_bytes": b"", "filename": filename, "success": False}

    # Use map to process URLs in parallel
    # num_proc can be adjusted based on available resources
    dataset = dataset.map(download_single_pdf, num_proc=32, desc="Downloading PDFs")

    return dataset


def expand_and_extract_pages(dataset: Dataset, page_count_column: str) -> Dataset:
    """
    Expands a dataset by creating new rows for each page number and extracts
    individual pages from PDFs. Expects dataset to have 'pdf_bytes' column
    and the specified page count column.

    Rows with success=False or missing pdf_bytes are skipped; a PDF that
    fails to parse is logged and skipped as a whole.

    Args:
        dataset (Dataset): Dataset containing PDFs
        page_count_column (str): Name of the column containing the page count

    Returns:
        Dataset: Expanded dataset with new rows for each page and 'page_bytes' column
                containing extracted pages
    """
    expanded_data = []

    # Process each PDF document
    for example in dataset:
        # Skip rows flagged as failed downloads or with no PDF payload.
        # NOTE(review): failed downloads store "" for pdf_bytes, not None,
        # so this relies on the success flag — confirm both producers.
        if not example.get("success", True) or example["pdf_bytes"] is None:
            continue

        try:
            # Open PDF from bytes once per document
            pdf_stream = BytesIO(example["pdf_bytes"])
            src_pdf = fitz.open(stream=pdf_stream, filetype="pdf")
            page_count = int(example[page_count_column])

            # Create a new row for each page
            # TODO - change back to 1.
            # (Step is already 1; page_num is 1-based.)
            for page_num in range(1, page_count + 1, 1):
                example_copy = copy.deepcopy(example)
                # Drop the full-document bytes from the per-page row to
                # keep the expanded dataset small.
                example_copy["pdf_bytes"] = "REMOVED"
                example_copy["page_number"] = page_num
                # Create new PDF with single page
                dst_pdf = fitz.open()

                # TODO - check two papers at a time, will lead to dupes but we can filter them out. This lets solutions be across pages.
                # NOTE(review): from_page=page_num-1, to_page=page_num-0
                # copies a TWO-page window (current + next), matching the
                # TODO above; presumably fitz clamps to_page on the last
                # page — confirm against PyMuPDF's insert_pdf docs.
                dst_pdf.insert_pdf(
                    src_pdf, from_page=page_num - 1, to_page=page_num - 0
                )

                # Save to bytes
                output_stream = BytesIO()
                dst_pdf.save(output_stream)
                example_copy["page_bytes"] = output_stream.getvalue()

                # Clean up
                dst_pdf.close()

                expanded_data.append(example_copy)

            # Clean up the source PDF
            src_pdf.close()

        except Exception as e:
            # Best-effort: log and skip the whole document on any failure.
            print(
                f"Error processing PDF from {example.get('url', 'unknown URL')}: {str(e)}"
            )
            continue

    # Create new dataset from expanded data
    return Dataset.from_list(expanded_data)


def expand_by_pages(dataset, page_count_column):
    """
    Expands a dataset by creating new rows for each page number.

    Each input row is duplicated once per page, with a 1-based
    'page_number' column added. Unlike expand_and_extract_pages, no
    page content is extracted here.

    Args:
        dataset (Dataset): Input Hugging Face dataset
        page_count_column (str): Name of the column containing the page count

    Returns:
        Dataset: Expanded dataset with one row per (document, page_number)
    """
    expanded_rows = []

    for example in dataset:
        page_count = int(example[page_count_column])

        # Create a new row for each page (1-based numbering).
        for page_num in range(1, page_count + 1):
            # Deep copy so rows don't share mutable values.
            row = copy.deepcopy(example)
            row["page_number"] = page_num
            expanded_rows.append(row)

    # Create new dataset from expanded data
    return Dataset.from_list(expanded_rows)


def get_pdf_page_count(url, max_pages=500):
    """
    Fetch a PDF from *url* and return its page count, capped at *max_pages*.

    The cap guards downstream workers against OOM when extracting text
    from very long documents.

    Args:
        url (str): URL of the PDF to inspect
        max_pages (int): Upper bound on the returned count (default 500)

    Returns:
        int: min(actual page count, max_pages)

    Raises:
        requests.RequestException: on network failure or HTTP error status
    """
    # Timeout so a hung server can't stall the caller forever.
    response = requests.get(url, timeout=60)
    # Fail loudly on HTTP errors instead of feeding an error page to fitz;
    # callers (e.g. get_all_page_counts) already catch exceptions.
    response.raise_for_status()

    pdf_data = BytesIO(response.content)
    # Context manager guarantees the document is closed (no leak).
    with fitz.open(stream=pdf_data, filetype="pdf") as pdf_document:
        page_count = pdf_document.page_count

    # TODO - nodes will OOM if the text is too long
    return min(page_count, max_pages)